package vip.shuai7boy.trafficTemp.areaRoadFlow;


import com.alibaba.fastjson.JSONObject;
import org.apache.commons.collections.IteratorUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;
import vip.shuai7boy.trafficTemp.conf.ConfigurationManager;
import vip.shuai7boy.trafficTemp.constant.Constants;
import vip.shuai7boy.trafficTemp.dao.ITaskDAO;
import vip.shuai7boy.trafficTemp.dao.factory.DAOFactory;
import vip.shuai7boy.trafficTemp.domain.Task;
import vip.shuai7boy.trafficTemp.util.DateUtils;
import vip.shuai7boy.trafficTemp.util.NumberUtils;
import vip.shuai7boy.trafficTemp.util.ParamUtils;
import vip.shuai7boy.trafficTemp.util.SparkUtils;
import vip.spark.spark.test.MockData;

import java.util.*;

/**
 * Road one-step conversion rate analysis.
 *
 * <p>Given a configured flow of monitors (for example {@code A,B,C}), this job counts how often
 * vehicles pass through each prefix of the flow ({@code A}, {@code A,B}, {@code A,B,C}) in order and
 * derives the conversion rate between consecutive prefixes, e.g.
 * {@code count(A,B) / count(A)} and {@code count(A,B,C) / count(A,B)}.
 */
public class MonitorOneStepConvertRateAnalyze {


    /**
     * Job entry point: sets up Spark, loads the task parameters, computes the conversion rates and
     * prints them to standard output.
     *
     * @param args command line arguments; the task id is resolved from them by
     *             {@code ParamUtils.getTaskIdFromArgs}
     */
    public static void main(String[] args) {

        JavaSparkContext sc = null;
        SparkSession spark = null;
        try {
            // Decide whether the application runs locally or on a cluster with Hive.
            boolean onLocal = ConfigurationManager.getBoolean(Constants.SPARK_LOCAL);
            if (onLocal) {
                SparkConf conf = new SparkConf().setAppName(Constants.SPARK_APP_NAME).setMaster(Constants.LOCAL);
                sc = new JavaSparkContext(conf);
                spark = SparkSession.builder().getOrCreate();
                // Local runs work against mock data instead of Hive tables.
                MockData.mock(sc, spark);
            } else {
                System.out.println("++++++++++++++++++++++++++++++++++++++开启hive的支持");
                /*
                 * For "select * from table1 join table2 on (join condition)", Spark broadcasts a
                 * table to the executors when it is smaller than
                 * spark.sql.autoBroadcastJoinThreshold (Spark's default is 10 MB), which avoids a
                 * shuffle and is more efficient. The value is in bytes; here the threshold is
                 * raised to 1048576000 bytes (roughly 1000 MiB).
                 */
                spark = SparkSession.builder()
                        .config("spark.sql.autoBroadcastJoinThreshold", "1048576000")
                        .enableHiveSupport()
                        .getOrCreate();

                spark.sql("use traffic");
                sc = new JavaSparkContext(spark.sparkContext());
            }

            // Resolve the task id.
            long taskid = ParamUtils.getTaskIdFromArgs(args, Constants.SPARK_LOCAL_TASKID_MONITOR_ONE_STEP_CONVERT);

            // Look up the task definition through the task DAO.
            ITaskDAO taskDAO = DAOFactory.getTaskDAO();
            Task task = taskDAO.findTaskById(taskid);
            if (task == null) {
                System.err.println("No task found for task id " + taskid + ", nothing to do.");
                return;
            }

            // Parse the task parameters as JSON.
            JSONObject taskParam = JSONObject.parseObject(task.getTaskParams());
            System.out.println("taskParam:" + taskParam.toJSONString());

            // The monitor flow whose conversion rates are to be computed.
            String roadFlow = ParamUtils.getParam(taskParam, Constants.PARAM_MONITOR_FLOW);
            System.out.println("roadFlow:" + roadFlow);
            if (roadFlow == null || roadFlow.isEmpty()) {
                // Without a flow there is nothing to match; fail early on the driver instead of
                // failing (or producing meaningless counts) inside the executors.
                System.err.println("Task " + taskid + " does not define a monitor flow, nothing to do.");
                return;
            }

            // Broadcast the flow so every executor can read it.
            Broadcast<String> roadFlowBroadcast = sc.broadcast(roadFlow);

            // Load the monitoring records selected by the task parameters.
            JavaRDD<Row> rowRDDByDateRange = SparkUtils.getCameraRDDByDateRange(spark, taskParam);

            // Step 1: turn every record into a (car, row) pair.
            JavaPairRDD<String, Row> car2RowRDD = getCar2RowRDD(rowRDDByDateRange);

            // Step 2: for every car, count how often each flow prefix occurs in its track.
            JavaPairRDD<String, Long> roadSplitRDD = generateAndMatchRowSplit(taskParam, roadFlowBroadcast,
                    car2RowRDD);

            // Step 3: sum the counts of identical prefixes over all cars.
            Map<String, Long> roadFlow2Count = getRoadFlowCount(roadSplitRDD);

            // Step 4: compute the conversion rates.
            Map<String, Double> convertRateMap = computerRoadSplitConvertRate(roadFlowBroadcast, roadFlow2Count);
            for (Map.Entry<String, Double> entry : convertRateMap.entrySet()) {
                System.out.println(entry.getKey() + "=" + entry.getValue());
            }
        } finally {
            // Always release Spark resources, including on the early returns above.
            if (spark != null) {
                spark.stop();
            } else if (sc != null) {
                sc.stop();
            }
        }
    }


    /**
     * Computes the conversion rate for every prefix of the flow except the first.
     *
     * <p>For prefix {@code i} the rate is {@code count(prefix i) / count(last prefix with a non-zero
     * count)}, formatted by {@code NumberUtils.formatDouble(value, 2)}. Prefixes whose count is
     * missing or zero produce no entry.
     *
     * @param roadFlow       broadcast holding the comma separated monitor flow
     * @param roadFlow2Count total match count per flow prefix
     * @return conversion rate per flow prefix, in flow order
     */
    private static Map<String, Double> computerRoadSplitConvertRate(Broadcast<String> roadFlow,
                                                                    Map<String, Long> roadFlow2Count) {
        // LinkedHashMap keeps the entries in flow order so the printed result is deterministic.
        Map<String, Double> rateMap = new LinkedHashMap<>();
        String[] monitors = roadFlow.value().split(",");

        long lastMonitorCarCount = 0L;
        StringBuilder prefixBuilder = new StringBuilder();
        for (int i = 0; i < monitors.length; i++) {
            if (i > 0) {
                prefixBuilder.append(',');
            }
            prefixBuilder.append(monitors[i]);
            String prefix = prefixBuilder.toString();

            Long count = roadFlow2Count.get(prefix);
            if (count == null || count == 0L) {
                continue;
            }
            // The first prefix has no predecessor, so it only seeds the denominator.
            if (i != 0 && lastMonitorCarCount != 0L) {
                double rate = NumberUtils.formatDouble((double) count / (double) lastMonitorCarCount, 2);
                rateMap.put(prefix, rate);
            }
            lastMonitorCarCount = count;
        }
        return rateMap;
    }

    /**
     * Sums the per-car match counts of each flow prefix and collects the totals to the driver.
     *
     * @param roadSplitRDD (flow prefix, match count) pairs
     * @return total match count per flow prefix
     */
    private static Map<String, Long> getRoadFlowCount(JavaPairRDD<String, Long> roadSplitRDD) {
        return roadSplitRDD.reduceByKey(new Function2<Long, Long, Long>() {

            private static final long serialVersionUID = 1L;

            @Override
            public Long call(Long v1, Long v2) throws Exception {
                return v1 + v2;
            }
        }).collectAsMap();
    }


    /**
     * Counts, per car, how often each prefix of the configured flow occurs in the car's track.
     *
     * <p>The rows of each car are sorted by time, their monitor ids are joined into a comma
     * separated track, and every prefix of the flow is matched against that track on whole
     * monitor ids. One (prefix, count) pair is emitted per prefix and car, including zero counts.
     *
     * @param taskParam         task parameters (currently not used by this method)
     * @param roadFlowBroadcast broadcast holding the comma separated monitor flow
     * @param car2RowRDD        (car, row) pairs
     * @return (flow prefix, match count) pairs, one per prefix and car
     */
    private static JavaPairRDD<String, Long> generateAndMatchRowSplit(JSONObject taskParam,
                                                                      final Broadcast<String> roadFlowBroadcast,
                                                                      JavaPairRDD<String, Row> car2RowRDD) {
        return car2RowRDD.groupByKey().flatMapToPair(
                new PairFlatMapFunction<Tuple2<String, Iterable<Row>>, String, Long>() {

            private static final long serialVersionUID = 1L;

            @Override
            public Iterator<Tuple2<String, Long>> call(Tuple2<String, Iterable<Row>> tuple) throws Exception {
                List<Row> rows = new ArrayList<>();
                for (Row row : tuple._2) {
                    rows.add(row);
                }

                // Sort the rows of this car by the time stored in column 4.
                // NOTE(review): column 4 is presumably the pass-through (action) time - confirm
                // against the table schema.
                Collections.sort(rows, new Comparator<Row>() {
                    @Override
                    public int compare(Row row1, Row row2) {
                        String actionTime1 = row1.getString(4);
                        String actionTime2 = row2.getString(4);
                        if (DateUtils.after(actionTime1, actionTime2)) {
                            return 1;
                        }
                        if (DateUtils.after(actionTime2, actionTime1)) {
                            return -1;
                        }
                        // Neither is after the other: report equality so the comparator honours
                        // its contract (never returning 0 can make the sort throw).
                        return 0;
                    }
                });

                // Join the monitor ids into ",id1,id2,...,idN,". The surrounding commas let the
                // matching below work on whole ids only (flow "1,2" must not match track "11,22").
                StringBuilder trackBuilder = new StringBuilder(",");
                for (Row row : rows) {
                    Object monitorId = row.getAs("monitor_id");
                    trackBuilder.append(monitorId).append(',');
                }
                String carTrack = trackBuilder.toString();

                // Match every prefix of the configured flow against the track.
                String[] monitors = roadFlowBroadcast.value().split(",");
                List<Tuple2<String, Long>> result = new ArrayList<>(monitors.length);
                StringBuilder prefixBuilder = new StringBuilder();
                for (int i = 0; i < monitors.length; i++) {
                    if (i > 0) {
                        prefixBuilder.append(',');
                    }
                    prefixBuilder.append(monitors[i]);
                    String prefix = prefixBuilder.toString();

                    long count = countOccurrences(carTrack, "," + prefix + ",");
                    result.add(new Tuple2<String, Long>(prefix, count));
                }

                return result.iterator();
            }
        });

    }

    /**
     * Counts the occurrences of {@code pattern} in {@code text}; overlapping occurrences count.
     *
     * @param text    text to search in
     * @param pattern non-empty text to search for
     * @return number of positions at which {@code pattern} occurs in {@code text}
     */
    private static long countOccurrences(String text, String pattern) {
        long count = 0L;
        int from = text.indexOf(pattern);
        while (from != -1) {
            count++;
            // Advance by one character only, so overlapping occurrences are found as well.
            from = text.indexOf(pattern, from + 1);
        }
        return count;
    }


    /**
     * Turns every row into a (car, row) pair, keyed by the row's {@code car} column.
     *
     * @param car2RowRDD the monitoring rows to key
     * @return (car, row) pairs
     */
    private static JavaPairRDD<String, Row> getCar2RowRDD(JavaRDD<Row> car2RowRDD) {
        return car2RowRDD.mapToPair(new PairFunction<Row, String, Row>() {

            private static final long serialVersionUID = 1L;

            @Override
            public Tuple2<String, Row> call(Row row) throws Exception {
                String car = row.getAs("car");
                return new Tuple2<String, Row>(car, row);
            }
        });
    }


}
